{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6adfa9fa-b760-49d8-be45-65fb67c5ab48",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# CodeLlama-34b seq-scheduler rollingbatch deployment guide\n",
    "In this tutorial, you will deploy a Large Model Inference (LMI) container from the Deep Learning Containers (DLC) to SageMaker and run inference with it.\n",
    "\n",
    "Please make sure the following permissions are granted before running the notebook:\n",
    "\n",
    "* S3 bucket push access\n",
    "* SageMaker access\n",
    "\n",
    "## Step 1: Let's upgrade the SageMaker SDK and import dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59dcc3aa-2cf6-44fc-95d8-d3fc819b5593",
   "metadata": {
    "tags": [],
    "pycharm": {
     "name": "#%%\n",
     "is_executing": true
    }
   },
   "outputs": [],
   "source": [
    "%pip install sagemaker --upgrade  --quiet\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e45b3aab-7136-4c94-8c60-20a40191d08f",
   "metadata": {
    "tags": [],
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import boto3\n",
    "import jinja2\n",
    "import sagemaker\n",
    "from sagemaker import Model, image_uris, serializers, deserializers\n",
    "\n",
    "role = sagemaker.get_execution_role()  # execution role for the endpoint\n",
    "sess = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs\n",
    "bucket = sess.default_bucket()  # bucket to house artifacts\n",
    "model_bucket = sess.default_bucket()  # bucket to house model artifacts\n",
    "s3_code_prefix = \"hf-large-model-djl/code_llama2\"  # folder within bucket where code artifact will go\n",
    "\n",
    "region = sess.boto_region_name  # region name of the current SageMaker Studio environment\n",
    "account_id = sess.account_id()  # account_id of the current SageMaker Studio environment\n",
    "\n",
    "s3_client = boto3.client(\"s3\")\n",
    "sm_client = boto3.client(\"sagemaker\")\n",
    "smr_client = boto3.client(\"sagemaker-runtime\")\n",
    "\n",
    "jinja_env = jinja2.Environment()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4b0f3f23-33ef-4a39-98fc-cbe03e30fcd6",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Step 2: Start preparing model artifacts\n",
    "The LMI container expects some artifacts to help set up the model:\n",
    "\n",
    "* serving.properties (required): Defines the model server settings\n",
    "* model.py (optional): A Python file that defines the core inference logic\n",
    "* requirements.txt (optional): Any additional pip packages that need to be installed"
   ]
  },
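  {
   "cell_type": "markdown",
   "id": "b3a1c7e2-props-sketch",
   "metadata": {},
   "source": [
    "The `serving.properties` file listed above is a plain list of `key=value` lines. As a minimal sketch, assuming you keep the settings in a Python dict, the hypothetical helper `render_properties` below (an illustration, not part of LMI or the SageMaker SDK) turns such a dict into those lines:\n",
    "\n",
    "```python\n",
    "def render_properties(options):\n",
    "    '''Render a dict as key=value lines, the layout used by serving.properties.'''\n",
    "    lines = []\n",
    "    for key, value in options.items():\n",
    "        # Python booleans are written in lowercase (true/false)\n",
    "        if isinstance(value, bool):\n",
    "            value = str(value).lower()\n",
    "        lines.append(f'{key}={value}')\n",
    "    return lines\n",
    "\n",
    "\n",
    "# Example settings, chosen for illustration only\n",
    "example = {\n",
    "    'engine': 'Python',\n",
    "    'option.dtype': 'fp16',\n",
    "    'option.max_rolling_batch_size': 4,\n",
    "    'option.trust_remote_code': True,\n",
    "}\n",
    "for line in render_properties(example):\n",
    "    print(line)\n",
    "```\n",
    "\n",
    "Each returned line can then be written to `serving.properties`, for example with `print(line, file=f)`."
   ]
  },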
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b825f1b5-0d86-4f6b-91db-8ecfb361066c",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['MODEL_ID'] = \"codellama/CodeLlama-34b-hf\"\n",
    "os.environ['HF_TRUST_REMOTE_CODE'] = \"TRUE\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d42ce3c-fb40-4869-ab49-4eaa881b5ff9",
   "metadata": {
    "tags": [],
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "model_id = os.getenv('MODEL_ID')\n",
    "with open('serving.properties', 'w') as f:\n",
    "    f.write(f\"\"\"engine=Python\n",
    "option.model_id={model_id}\n",
    "option.tensor_parallel_degree=4\n",
    "option.dtype=fp16\n",
    "option.model_loading_timeout=3600\n",
    "option.trust_remote_code=true\n",
    "\n",
    "# rolling-batch parameters\n",
    "option.max_rolling_batch_size=4\n",
    "option.rolling_batch=scheduler\n",
    "\n",
    "# seq-scheduler parameters\n",
    "# limits the max_sparsity in the token sequence caused by padding\n",
    "option.max_sparsity=0.33\n",
    "# limits the max number of batch splits, where each split has its own inference call\n",
    "option.max_splits=3\n",
    "# other options: contrastive, sample\n",
    "option.decoding_strategy=greedy\n",
    "# default: true\n",
    "option.disable_flash_attn=true\n",
    "\"\"\")"
   ]
  },
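  {
   "cell_type": "markdown",
   "id": "c4d2e8f1-sparsity-sketch",
   "metadata": {},
   "source": [
    "The `max_sparsity` comment in the cell above refers to padding. A rough way to picture it, assuming sparsity means the share of padded positions when every sequence in a batch is padded to the longest one (the scheduler's exact definition may differ), is the hypothetical helper `padding_sparsity` below. It is an illustration only and not part of LMI:\n",
    "\n",
    "```python\n",
    "def padding_sparsity(seq_lengths):\n",
    "    '''Share of positions that are padding when all sequences are padded to the longest one.'''\n",
    "    if not seq_lengths:\n",
    "        return 0.0\n",
    "    width = max(seq_lengths)  # padded length of every row\n",
    "    total = width * len(seq_lengths)  # positions in the padded batch\n",
    "    padded = total - sum(seq_lengths)  # positions filled with padding\n",
    "    return padded / total\n",
    "\n",
    "\n",
    "# Assumed example batches: similar lengths vs. very different lengths\n",
    "print(padding_sparsity([10, 10, 4]))\n",
    "print(padding_sparsity([8, 2]))\n",
    "```\n",
    "\n",
    "Under this assumption, a batch whose value exceeds `max_sparsity` would be a candidate for splitting into separate inference calls, up to `max_splits`."
   ]
  },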
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed9512a3-9603-4b39-ae92-45481b397c00",
   "metadata": {
    "tags": [],
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "%%sh\n",
    "mkdir mymodel\n",
    "mv serving.properties mymodel/\n",
    "tar czvf mymodel-3-code-llama.tar.gz mymodel/\n",
    "rm -rf mymodel"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6e6e5f3b-0450-4365-b67b-d2a9f2cfcac4",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Step 3: Start building SageMaker endpoint\n",
    "In this step, we will build a SageMaker endpoint from scratch.\n",
    "\n",
    "### Getting the container image URI\n",
    "\n",
    "The available LMI images are listed under [Large Model Inference containers](https://github.com/aws/deep-learning-containers/blob/master/available_images.md#large-model-inference-containers).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "765de906-9747-4d69-aa78-0b0a359bd57f",
   "metadata": {
    "tags": [],
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "image_uri = image_uris.retrieve(\n",
    "    framework=\"djl-deepspeed\",\n",
    "    region=sess.boto_session.region_name,\n",
    "    version=\"0.27.0\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cbd9f984-fc41-47e9-b17e-3b96fb75f49e",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### Upload artifact on S3 and create SageMaker model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53bea8c5-8fbf-463b-b430-87d6c32b8806",
   "metadata": {
    "tags": [],
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "s3_code_prefix = \"large-model-lmi/code\"\n",
    "bucket = sess.default_bucket()  # bucket to house artifacts\n",
    "code_artifact = sess.upload_data(\"mymodel-3-code-llama.tar.gz\", bucket, s3_code_prefix)\n",
    "print(f\"Code/model tarball uploaded to S3: {code_artifact}\")\n",
    "\n",
    "model = Model(image_uri=image_uri, model_data=code_artifact, role=role)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "787cf1c3-ea9f-4a83-84d6-480bbb07f57b",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Step 4: Create SageMaker endpoint\n",
    "Specify the instance type and the endpoint name, then deploy the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08b153e2-3eb6-4882-aeb5-81d7efbb7a9a",
   "metadata": {
    "tags": [],
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------------"
     ]
    }
   ],
   "source": [
    "instance_type = \"ml.g5.12xlarge\"\n",
    "endpoint_name = sagemaker.utils.name_from_base(\"lmi-model-3-code-llama\")\n",
    "\n",
    "model.deploy(initial_instance_count=1,\n",
    "             instance_type=instance_type,\n",
    "             endpoint_name=endpoint_name,\n",
    "             container_startup_health_check_timeout=3600\n",
    "             )\n",
    "\n",
    "# our requests and responses will be in json format so we specify the serializer and the deserializer\n",
    "predictor = sagemaker.Predictor(\n",
    "    endpoint_name=endpoint_name,\n",
    "    sagemaker_session=sess,\n",
    "    serializer=serializers.JSONSerializer(),\n",
    "    deserializer=deserializers.JSONDeserializer(),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c5e60c73-f3d8-43ef-9261-23677f03d5cb",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Step 5: Test and benchmark the inference\n",
    "First, let's send a single request to check that the endpoint responds."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e21c0102-132c-4c18-8627-223592578c86",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "predictor.predict(\n",
    "    {\"inputs\": \"def hello_world():\", \"parameters\": {\"max_new_tokens\": 128, \"do_sample\": True}}\n",
    ")"
   ]
  },
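  {
   "cell_type": "markdown",
   "id": "d5e3f9a2-payload-sketch",
   "metadata": {},
   "source": [
    "The request body sent above is a JSON object with an `inputs` string and a `parameters` map. As a minimal sketch, the hypothetical helper `build_payload` below (not part of the SageMaker SDK) builds the same body with `json.dumps`, which serializes a Python `True` as the JSON boolean `true`:\n",
    "\n",
    "```python\n",
    "import json\n",
    "\n",
    "\n",
    "def build_payload(prompt, max_new_tokens=128, do_sample=True):\n",
    "    '''Return the JSON request body for a text-generation call.'''\n",
    "    body = {\n",
    "        'inputs': prompt,\n",
    "        'parameters': {'max_new_tokens': max_new_tokens, 'do_sample': do_sample},\n",
    "    }\n",
    "    return json.dumps(body)\n",
    "\n",
    "\n",
    "print(build_payload('def hello_world():'))\n",
    "```\n",
    "\n",
    "The resulting string has the same shape as the `-d` argument of a curl-style request."
   ]
  },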
  {
   "cell_type": "markdown",
   "id": "aa8fba6e-aa0b-44b1-82d5-9d369e38e8bb",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### Benchmark with awscurl"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "790bf3f4-9398-41dd-8648-b6e43053b3e7",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "The benchmark can be run outside this notebook, in a bash shell. The server is reached through the `$SAGEMAKER` URL. The `awscurl` used here is a benchmark tool that can be downloaded with:\n",
    "```\n",
    "wget https://github.com/frankfliu/junkyard/releases/download/v0.3.1/awscurl && chmod +x awscurl\n",
    "```\n",
    "It can be replaced with a regular curl command."
   ]
  },
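  {
   "cell_type": "markdown",
   "id": "e6f4a0b3-url-sketch",
   "metadata": {},
   "source": [
    "The `$SAGEMAKER` URL mentioned above follows the pattern visible in the benchmark command: the region appears in the host name and the endpoint name in the path. As a small sketch, the hypothetical helper `invocation_url` below assembles it from those two values; `my-endpoint` is a placeholder, not a real endpoint:\n",
    "\n",
    "```python\n",
    "def invocation_url(region, endpoint_name):\n",
    "    '''Build the invocation URL for an endpoint in the given region.'''\n",
    "    return (\n",
    "        f'https://runtime.sagemaker.{region}.amazonaws.com'\n",
    "        f'/endpoints/{endpoint_name}/invocations'\n",
    "    )\n",
    "\n",
    "\n",
    "# Placeholder values; in this notebook you would pass region and endpoint_name\n",
    "print(invocation_url('us-west-2', 'my-endpoint'))\n",
    "```\n",
    "\n",
    "The printed value can be exported as `SAGEMAKER` before running the benchmark."
   ]
  },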
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ced57091-9962-4f6b-9fc5-6f07207d4867",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "%%sh\n",
    "export CONCUR=4\n",
    "# Replace the region and endpoint name in the URL below with those of your own endpoint\n",
    "export SAGEMAKER=https://runtime.sagemaker.us-west-2.amazonaws.com/endpoints/lmi-model-2-code-llama-2023-11-11-03-17-14-364/invocations\n",
    "TOKENIZER=codellama/CodeLlama-34b-hf ./awscurl -c $CONCUR -N 10 -n sagemaker $SAGEMAKER \\\n",
    "  --connect-timeout 660 \\\n",
    "  -H \"Content-type: application/json\" \\\n",
    "  -d '{\"inputs\":\"The new movie that got Oscar this year\",\"parameters\":{\"max_new_tokens\":256, \"do_sample\":true, \"temperature\":0.8, \"top_k\":5}}' \\\n",
    "  -t -o output-3-$CONCUR.txt; \\\n",
    "mv output-3-$CONCUR.txt.0 /home/ec2-user/SageMaker/output_code-llama"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5f69dbef-d858-46ed-a163-0a15c65056ca",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Clean up"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d56c6f57-4554-4ea1-ad58-f0a12e294d44",
   "metadata": {
    "tags": [],
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "sess.delete_endpoint(endpoint_name)\n",
    "sess.delete_endpoint_config(endpoint_name)\n",
    "model.delete_model()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_pytorch_p310",
   "language": "python",
   "name": "conda_pytorch_p310"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}